{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'tensorflow'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
      "Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 10>\u001B[1;34m()\u001B[0m\n\u001B[0;32m      8\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01msys\u001B[39;00m\n\u001B[0;32m      9\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtime\u001B[39;00m\n\u001B[1;32m---> 10\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mtensorflow\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mtf\u001B[39;00m\n\u001B[0;32m     12\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mtensorflow\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m keras\n\u001B[0;32m     14\u001B[0m \u001B[38;5;28mprint\u001B[39m(tf\u001B[38;5;241m.\u001B[39m__version__)\n",
      "\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'tensorflow'"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "# Report the interpreter and library versions this notebook runs against.\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for module in (mpl, np, pd, sklearn, tf, keras):\n",
    "    print(module.__name__, module.__version__)\n",
    "\n",
    "# Fashion-MNIST image-classification dataset shipped with Keras.\n",
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "\n",
    "(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()\n",
    "\n",
    "# The first 5000 samples become the validation set; the remainder is the training set.\n",
    "x_valid, y_valid = x_train_all[:5000], y_train_all[:5000]\n",
    "x_train, y_train = x_train_all[5000:], y_train_all[5000:]\n",
    "\n",
    "# Shapes of the training, validation and test splits, in that order.\n",
    "print(x_train.shape, y_train.shape)\n",
    "print(x_valid.shape, y_valid.shape)\n",
    "print(x_test.shape, y_test.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "28*28  #如果非要对应多少特征，就是784个特征"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "type(x_train)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "y_train[0]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def show_single_image(img_arr):\n",
    "    \"\"\"Draw one image with a colour bar beside it and show the figure.\n",
    "\n",
    "    img_arr is handed unchanged to plt.imshow and rendered with the\n",
    "    \"binary\" colormap.\n",
    "    \"\"\"\n",
    "    drawn = plt.imshow(img_arr, cmap=\"binary\")\n",
    "    plt.colorbar(drawn)  # colour scale for the image just drawn\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Inspect x_train[0], x_train[1], x_train[2] in turn to make sense of the plotting loop in the next cell.\n",
    "print(x_train[0])\n",
    "show_single_image(x_train[1])\n",
    "print(x_train[1,0,18])\n",
    "print(y_train[1])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def show_imgs(n_rows, n_cols, x_data, y_data, class_names):\n",
    "    \"\"\"Show the first n_rows * n_cols images of x_data in a grid, titled by class name.\n",
    "\n",
    "    n_rows, n_cols: grid dimensions.\n",
    "    x_data: indexable collection of images; each one is passed to plt.imshow.\n",
    "    y_data: labels aligned with x_data; each label is used as an index into class_names.\n",
    "    class_names: sequence mapping a label to the text shown above its image.\n",
    "\n",
    "    Raises AssertionError if x_data and y_data differ in length, or if the grid\n",
    "    needs more images than x_data holds.\n",
    "    \"\"\"\n",
    "    assert len(x_data) == len(y_data)  # one label per image\n",
    "    # <= rather than <: a grid that uses every available image is still valid.\n",
    "    assert n_rows * n_cols <= len(x_data)\n",
    "    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))  # 1.4 wide by 1.6 high per tile\n",
    "    for index in range(n_rows * n_cols):\n",
    "        plt.subplot(n_rows, n_cols, index + 1)  # subplot numbering starts at 1\n",
    "        plt.imshow(x_data[index], cmap=\"binary\",\n",
    "                   interpolation='nearest')\n",
    "        plt.axis('off')  # hide the axes around each image\n",
    "        plt.title(class_names[y_data[index]])\n",
    "    plt.show()\n",
    "# Known image classes.\n",
    "# The labels are documented at https://github.com/zalandoresearch/fashion-mnist\n",
    "class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress',\n",
    "               'Coat', 'Sandal', 'Shirt', 'Sneaker',\n",
    "               'Bag', 'Ankle boot']\n",
    "# Only the first 20 samples (4 rows x 5 columns) are shown.\n",
    "show_imgs(4, 5, x_train, y_train, class_names)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# NOTE(review): this cell was a verbatim copy of the previous one (show_imgs and\n",
    "# class_names), while `model`, which every later cell uses, was never defined.\n",
    "# The network below is reconstructed from the notes further down (4 layers, first\n",
    "# kernel W.shape == [784, 300], sparse categorical cross-entropy loss); the\n",
    "# activations, the 100-unit layer and the optimizer are assumptions -- confirm.\n",
    "model = keras.models.Sequential([\n",
    "    keras.layers.Flatten(input_shape=[28, 28]),  # 28 * 28 = 784 features per image\n",
    "    keras.layers.Dense(300, activation=\"relu\"),\n",
    "    keras.layers.Dense(100, activation=\"relu\"),\n",
    "    keras.layers.Dense(10, activation=\"softmax\"),  # one output per entry in class_names\n",
    "])\n",
    "\n",
    "# Labels are integer class ids, hence the sparse variant of the loss.\n",
    "# Tracking accuracy gives the learning-curve plot a metric on the 0-1 scale.\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\",\n",
    "              optimizer=\"sgd\",\n",
    "              metrics=[\"accuracy\"])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model.layers  #总计有4层"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "#可以来算一下参数个数\n",
    "model.summary()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model.variables  #模型中自己训练的参数"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Glorot (Xavier) uniform initialisation draws weights from U(-limit, limit) with\n",
    "# limit = sqrt(6 / (fan_in + fan_out)). For the first Dense kernel, W.shape == [784, 300].\n",
    "# The largest weight shown in the next cell should sit just below this bound.\n",
    "np.sqrt(6 / (784 + 300))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model.variables[0].numpy().max()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Glorot uniform bound for the second Dense kernel: sqrt(6 / (fan_in + fan_out)).\n",
    "# Assumes that kernel has shape [300, 100] -- confirm against model.summary().\n",
    "np.sqrt(6 / (300 + 100))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model.variables[2].numpy().max()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# [None, 784] @ W + b -> [None, 300], with W.shape == [784, 300] and b.shape == [300]\n",
    "# W is the weight matrix; b is the bias, which is a vector."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "history = model.fit(x_train, y_train, epochs=20,\n",
    "                    validation_data=(x_valid, y_valid))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "type(history)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Per-epoch metric history (20 entries per metric, matching epochs=20 in model.fit).\n",
    "history.history"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def plot_learning_curves(history):\n",
    "    \"\"\"Plot every series stored in history.history as a line chart.\n",
    "\n",
    "    history: object whose .history attribute can be turned into a DataFrame\n",
    "    (one column per recorded series, one row per entry).\n",
    "\n",
    "    The y-axis is fixed to [0, 1]. Loss values are not guaranteed to lie in that\n",
    "    range, so any part of a curve above 1 is cut off.\n",
    "    \"\"\"\n",
    "    curves = pd.DataFrame(history.history)\n",
    "    ax = curves.plot(figsize=(8, 5))\n",
    "    ax.grid(True)\n",
    "    ax.set_ylim(0, 1)\n",
    "    plt.show()\n",
    "\n",
    "plot_learning_curves(history)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model.evaluate(x_test, y_test, verbose=0)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Compute the cross-entropy loss between integer labels and predicted class probabilities.\n",
    "cce = keras.losses.SparseCategoricalCrossentropy()\n",
    "loss = cce(\n",
    "  [0, 1, 2],\n",
    "  [[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]])\n",
    "print('Loss: ', loss.numpy())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}